import torch
from einops import repeat, reduce

# repeat: 张量缩减，类似于tensorflow中的reduce操作，可以用于求平均值，最大最小值的同时压缩张量维度

# 维度压缩(减少维度)
x = torch.randn(3, 4, 5)
y = reduce(x, "i j k -> i j", "max")  # 最大值缩减 torch.Size([3, 4])

x = torch.randn(3, 4, 5)
y = reduce(x, "i j k -> i j", "mean")  # 平均值缩减 torch.Size([3, 4])

# 维度压缩(沿某维度压缩)
x = torch.randn(3, 4, 8)
y = reduce(x, "i j (n m)-> i j n", "mean", n=2)  # torch.Size([3, 4, 2])

pass
